package keras;

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Rule;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.resources.Resources;
import org.springframework.core.io.ClassPathResource;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;

/**
 * Smoke tests for importing Keras HDF5 models into DL4J via {@link KerasModelImport} and
 * {@link KerasModelBuilder}.
 *
 * <p>The test methods are JUnit Jupiter ({@code org.junit.jupiter.api.Test}) tests.
 */
public class ImportWgan {
    private static final String TEMP_OUTPUTS_FILENAME = "tempOutputs";
    private static final String TEMP_MODEL_FILENAME = "tempModel";
    private static final String H5_EXTENSION = ".h5";

    /** System property that overrides the location of the Keras model imported by {@link #test()}. */
    private static final String MODEL_PATH_PROPERTY = "wgan.model.path";

    /** The original, machine-specific model location; used when {@link #MODEL_PATH_PROPERTY} is unset. */
    private static final String DEFAULT_MODEL_PATH = "F:\\face\\model\\4full_modelg.h5";

    /**
     * JUnit 4 temporary-folder rule.
     *
     * <p>The Jupiter engine that runs this class's {@code @Test} methods does not process
     * {@code @Rule} fields, so this folder is never created and {@code newFile} on it would throw.
     * It is kept only so the public field stays available; temporary files are created with
     * {@link Files#createTempFile} instead.
     */
    @Rule
    public final TemporaryFolder testDir = new TemporaryFolder();

    /**
     * Imports the WGAN discriminator sample 100 times in a row to check that HDF5 loading does not
     * crash on repeated use.
     */
    @Test
    @Disabled("Requires modelimport/keras/examples/gans/wgan_discriminator.h5 to be resolvable via Resources")
    public void importWganDiscriminator() throws Exception {
        for (int i = 0; i < 100; i++) {
            // run a few times to make sure HDF5 doesn't crash
            importSequentialModelH5Test("modelimport/keras/examples/gans/wgan_discriminator.h5");
        }
    }

    /**
     * Imports a Keras model together with its weights from an HDF5 file into a
     * {@link ComputationGraph}.
     *
     * <p>The file location is taken from the {@code wgan.model.path} system property and falls back
     * to the original hard-coded path. If no file exists there, the test is skipped through a
     * JUnit assumption instead of failing. This keeps the suite runnable on machines without the model.
     *
     * @throws Exception if the import itself fails
     */
    @Test
    public void test() throws Exception {
        String modelPath = System.getProperty(MODEL_PATH_PROPERTY, DEFAULT_MODEL_PATH);
        File modelFile = new File(modelPath);
        Assumptions.assumeTrue(modelFile.isFile(),
                () -> "Keras model not found at " + modelPath
                        + "; set -D" + MODEL_PATH_PROPERTY + "=<path> to run this test");

        ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelFile.getAbsolutePath());
        Assertions.assertNotNull(model, "Keras import returned no ComputationGraph");
    }

    /**
     * Copies a model resource to a temporary {@code .h5} file and imports it as a sequential model.
     *
     * @param modelPath  resource path resolved through {@link Resources#asStream(String)}
     * @param inputShape optional input shape passed to the builder; {@code null} leaves it unset
     * @return the imported network
     * @throws Exception if the resource cannot be read or the import fails
     */
    private MultiLayerNetwork importSequentialModelH5Test(String modelPath, int[] inputShape) throws Exception {
        try (InputStream is = Resources.asStream(modelPath)) {
            File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION);
            Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
            KerasModelBuilder builder = new KerasModel().modelBuilder()
                    .modelHdf5Filename(modelFile.getAbsolutePath())
                    .enforceTrainingConfig(false);
            if (inputShape != null) {
                builder.inputShape(inputShape);
            }
            KerasSequentialModel model = builder.buildSequential();
            return model.getMultiLayerNetwork();
        }
    }

    /**
     * Creates an empty, uniquely named temporary file that is removed when the JVM exits.
     *
     * <p>Uses {@link Files#createTempFile} rather than the JUnit 4 {@link #testDir} rule, which is
     * never initialised under the Jupiter engine.
     *
     * @param prefix file name prefix; a {@code -} separator is appended
     * @param suffix file name suffix, e.g. {@code .h5}
     * @return the newly created file
     * @throws IOException if the file cannot be created
     */
    private File createTempFile(String prefix, String suffix) throws IOException {
        File file = Files.createTempFile(prefix + "-", suffix).toFile();
        file.deleteOnExit();
        return file;
    }

    /**
     * Imports a sequential model resource without overriding its input shape.
     *
     * @param modelPath resource path resolved through {@link Resources#asStream(String)}
     * @return the imported network
     * @throws Exception if the resource cannot be read or the import fails
     */
    private MultiLayerNetwork importSequentialModelH5Test(String modelPath) throws Exception {
        return importSequentialModelH5Test(modelPath, null);
    }

}
